# Simulating changes in the amount of rooftop solar being produced

# What we want: As solar panel installations increase or decrease...
#  - How do the marginal damages of production change?
#  - How does fuel mix change? 

# That means we have to 
# 1) Convert panel-damages script into a function that can fit production 
#    for any level of excess load (quickly)
#   - Scale annual output by hourly profile
#   - Calculate new excess load
#   - Use coefs to calc raw fitted value
#   - Draw epsilon and censor between (0, namepcap)
#   - Calculate damages
# 2) Figure out current level of 

# TODO: Enforce that production doesn't increase

library(pacman)
p_load(
  here, fst, data.table, ggplot2, magrittr, janitor, stringr, purrr,
  collapse
)

# Loading data ----------------------------------------------------------------
  # Marginal production model, read from the regression output file.
  # The plant ID is stored as a character copy of `orispl_code` and is used
  # as the table key.
  model_dt = fread(
    here("Data/electricity-generation/output/PPRegressors-oge.csv")
  )
  model_dt[, plant_id_eia := as.character(orispl_code)]
  setkey(model_dt, plant_id_eia)
  # Long version of the model: one row per non-zero excess-load coefficient
  # (the intercept is not included). Each row carries the region parsed from
  # the column name, a flag for the squared term, and, where available, the
  # estimated excess load at which the max/min happens.
  model_dt_long = local({
    # Coefficients on the excess-load terms
    coef_long = melt(
      model_dt,
      id.vars = 'plant_id_eia',
      measure.vars = patterns('excess_load_')
    )
    coef_long = coef_long[value != 0, .(
      plant_id_eia,
      nerc_adj = str_remove_all(variable, 'excess_load_|_sq'),
      square = str_detect(variable, '_sq$'),
      coef = value
    )]
    # Estimated extremum for each plant-region pair
    extremum_long = melt(
      model_dt,
      id.vars = 'plant_id_eia',
      measure.vars = patterns('est_extremum')
    )
    extremum_long = extremum_long[!is.na(value), .(
      plant_id_eia,
      nerc_adj = str_remove(variable, 'est_extremum_'),
      est_extremum = value
    )]
    # Left join: coefficients without an extremum keep est_extremum = NA
    merge(
      coef_long,
      extremum_long,
      by = c('plant_id_eia','nerc_adj'),
      all.x = TRUE
    )
  })
  # Some name cleaning, applied to model_dt by reference in two passes.
  # This has to run after model_dt_long is built, because that step parses
  # the original column names.
  invisible(local({
    # Pass 1: for every column containing 'sq', strip '_sq' and
    # 'excess_load_' and re-prefix what is left with 'excess_load_sq_'
    sq_cols = str_subset(colnames(model_dt), 'sq')
    sq_suffix = sq_cols |>
      str_remove('_sq') |>
      str_remove('excess_load_')
    setnames(
      model_dt,
      old = sq_cols,
      new = paste0('excess_load_sq_', sq_suffix)
    )
    # Pass 2: prefix every excess-load column (including the ones renamed
    # above) with 'coef_'
    load_cols = str_subset(colnames(model_dt), 'excess_load')
    setnames(
      model_dt,
      old = load_cols,
      new = paste0('coef_', load_cols)
    )
  }))
  # Load data: excess load by region and timestamp, plus its square.
  # Region codes are lower-cased and `utc_time` is renamed `datetime_utc`.
  nerc_load_dt = local({
    reg_load = fread(
      here("Data/electricity-generation/output/RegLoad-oge.csv")
    )
    reg_load[, .(
      nerc_adj = tolower(nerc_adj),
      datetime_utc = utc_time,
      excess_load,
      excess_load_sq = excess_load^2
    )]
  })
  # Emissions rates, reshaped wide: one row per plant and generation split
  # point, one column per output/pollutant combination
  emissions_rate_dt = local({
    rates_long = read.fst(
      path = here('Data/electricity-generation/output/emissions-md-dt.fst'),
      as.data.table = TRUE
    )
    rates_long[, plant_id_eia := as.character(plant_id_eia)]
    dcast(
      rates_long,
      plant_id_eia + net_generation_mwh_split ~ output + pollutant,
      value.var = 'emissions_rate_tons'
    )
  })
  # Marginal damages per MWh, with the plant ID stored as character
  tot_md_dt = fread(
    here('Data/electricity-generation/output/damage-per-mwh-oge.csv')
  )
  tot_md_dt[, plant_id_eia := as.character(plant_id_eia)]
  # Avg solar panel production by region (From Mark email 10/17/2022).
  # The figures are divided by 1000 before being stored as avg_output_mwh.
  panel_output_dt = local({
    region_output = c(
      cal = 5935.968, mro = 4802.954, npcc = 4546.329, rfc = 4584.62,
      serc = 5161.415, tre = 5274.348, wecc = 5444.394
    )
    data.table(
      nerc_adj = names(region_output),
      avg_output_mwh = unname(region_output)/1000
    )
  })
  # Time profile for solar production by state
  state_solar_profile_dt = fread(
    here("Data/electricity-generation/output/solar-profile-dt.csv")
  )
  # Crosswalk between states and nerc regions (one state FIPS code per region)
  nerc_state_xwalk = local({
    region_fips = c(
      cal = 6, mro = 29, npcc = 25, rfc = 39, serc = 13, tre = 48, wecc = 49
    )
    data.table(
      nerc_adj = names(region_fips),
      state_fips = unname(region_fips)
    )
  })
  # Plant info table: the earliest-year record of every plant from the two
  # plant-info files, inner-joined to the model, the emissions rates and the
  # marginal damages, then trimmed to the columns the simulation needs.
  plant_dt = local({
    # Read one plant-info file and keep each plant's earliest-year row
    read_first_year = function(fst_path){
      read.fst(path = here(fst_path), as.data.table = TRUE)[,
        .SD[which.min(year)],
        by = .(plant_id_eia = as.character(plant_id_eia))
      ]
    }
    plant_info = rbind(
      read_first_year("Data/electricity-generation/plant-info-dt-oge.fst"),
      read_first_year("Data/electricity-generation/shaped-info-dt-oge.fst"),
      fill = TRUE
    )
    # Inner joins: a plant missing from any of the three tables is dropped
    plant_wide = merge(plant_info, model_dt, by = 'plant_id_eia')
    plant_wide = merge(plant_wide, emissions_rate_dt, by = 'plant_id_eia')
    plant_wide = merge(plant_wide, tot_md_dt, by = 'plant_id_eia')
    # The split point is taken from the `.x` copy left by the merges
    plant_keep = plant_wide[, .(
      plant_id_eia,
      namepcap,
      log_scale,
      intercept,
      net_generation_mwh_split = net_generation_mwh_split.x,
      md_per_mwh_high,
      md_per_mwh_low,
      emissions_rate_low_co2e,
      emissions_rate_high_co2e,
      emissions_rate_low_nox,
      emissions_rate_high_nox,
      emissions_rate_low_so2,
      emissions_rate_high_so2,
      emissions_rate_low_pm25,
      emissions_rate_high_pm25
    )]
    setkey(plant_keep, plant_id_eia)
  })

# Determine the change in load for change in panels ---------------------------
  # Excess load by region and timestamp after adding `n_panels` panels.
  #
  # Args:
  #   n_panels: number of panels being simulated
  #   state_solar_profile_dt: solar profile (state_fips, utc_time, elec_profile)
  #   nerc_state_xwalk: crosswalk between state_fips and nerc_adj
  #   panel_output_dt: avg_output_mwh by nerc_adj
  #   nerc_load_dt: excess_load by nerc_adj and datetime_utc
  # Returns:
  #   Long data.table with columns nerc_adj, datetime_utc, n_panels, square
  #   and excess_load. Rows with square == FALSE hold the adjusted load; rows
  #   with square == TRUE hold its square.
  calc_adj_load = function(
    n_panels,
    state_solar_profile_dt, 
    nerc_state_xwalk, 
    panel_output_dt,
    nerc_load_dt
  ){
    # Attach the region and its per-panel output to the solar profile
    profile_dt = merge(
      state_solar_profile_dt,
      nerc_state_xwalk,
      by = 'state_fips'
    )
    profile_dt = merge(profile_dt, panel_output_dt, by = 'nerc_adj')
    # Total panel output for this number of panels
    panel_load_dt = profile_dt[, .(
      nerc_adj,
      datetime_utc = utc_time,
      n_panels = n_panels,
      panel_mwh = elec_profile*avg_output_mwh*n_panels
    )]
    # Subtract panel output from the existing load (inner join on
    # timestamp and region)
    adj_load_dt = merge(
      panel_load_dt,
      nerc_load_dt,
      by = c('datetime_utc','nerc_adj')
    )
    adj_load_dt = adj_load_dt[, .(
      nerc_adj,
      datetime_utc,
      n_panels,
      excess_load = excess_load - panel_mwh,
      excess_load_sq = (excess_load - panel_mwh)^2
    )]
    # Stack the level and squared columns, flagging which is which
    long_load_dt = melt(
      adj_load_dt,
      id.vars = c('nerc_adj','datetime_utc','n_panels'),
      measure.vars = patterns('excess_load')
    )
    long_load_dt[, .(
      nerc_adj,
      datetime_utc,
      n_panels,
      square = str_detect(variable, '_sq$'),
      excess_load = value
    )]
  }

  # Function to create table of epsilons for each plant
  #
  # For every row of `plant_dt`, draws one normal shock per timestamp with
  # mean 0 and standard deviation exp(log_scale).
  #
  # Args:
  #   plant_dt: table with columns plant_id_eia and log_scale
  #   datetime_utc: vector of timestamps to draw shocks for
  #   seed: passed to set.seed() before drawing (this resets the session's
  #     RNG state)
  # Returns:
  #   data.table keyed by (plant_id_eia, datetime_utc) with column epsilon.
  #
  # All shocks come from a single rnorm() call. rnorm fills its result
  # element by element and recycles `sd`, so for unique plant IDs the draws
  # match what one rnorm() call per plant, in row order, would produce.
  # If plant_dt ever holds a plant ID more than once, each row here uses its
  # own log_scale.
  generate_epsilons = function(plant_dt, datetime_utc, seed = 1234){
    stopifnot(
      all(c('plant_id_eia', 'log_scale') %in% names(plant_dt))
    )
    n_times = length(datetime_utc)
    plant_ids = as.character(plant_dt$plant_id_eia)
    n_plants = length(plant_ids)
    plant_sd = exp(plant_dt$log_scale)
    set.seed(seed)
    epsilon_dt = data.table(
      plant_id_eia = rep(plant_ids, each = n_times),
      datetime_utc = rep(datetime_utc, times = n_plants),
      epsilon = rnorm(
        n_times*n_plants, mean = 0,
        sd = rep(plant_sd, each = n_times)
      )
    )
    setkey(epsilon_dt, plant_id_eia, datetime_utc)
    return(epsilon_dt)
  }

# Calculating production ------------------------------------------------------
  # Fitted production, emissions and damages by plant for one load scenario.
  #
  # Args:
  #   load_dt: long load table with columns nerc_adj, datetime_utc, n_panels,
  #     square and excess_load (the shape calc_adj_load returns). Only the
  #     square == FALSE rows are read; squared terms are rebuilt here.
  #   plant_dt: plant table with intercept, namepcap, the generation split
  #     point, and the low/high emissions rates and marginal damages
  #   model_dt_long: long coefficient table with columns plant_id_eia,
  #     nerc_adj, square, coef and est_extremum
  #   epsilon_dt: shocks by plant_id_eia and datetime_utc
  # Returns:
  #   data.table with one row per (plant_id_eia, n_panels): summed fitted
  #   generation, emissions by pollutant, and damages.
  #
  # NOTE(review): the extremum cap is applied to the raw load before it is
  # squared. Previously the squared-term rows compared the already squared
  # load with est_extremum and, when capped, used coef*est_extremum rather
  # than coef*est_extremum^2. Linear-term rows are unchanged. This assumes
  # est_extremum is in the same units as excess_load - confirm against the
  # model estimation code.
  calc_production = function(
    load_dt, 
    plant_dt, 
    model_dt_long,
    epsilon_dt
  ){
    # Two-tier total: the first `split` MWh at `rate_low`, anything above
    # `split` at `rate_high`
    two_tier = function(gen, split, rate_low, rate_high){
      fifelse(
        gen >= split,
        (gen - split)*rate_high + split*rate_low,
        gen*rate_low
      )
    }
    # Contribution of one coefficient: cap the load at a positive extremum,
    # then square it if this is a squared term
    coef_term = function(coef, raw_load, est_extremum, square){
      capped = !is.na(est_extremum) & est_extremum > 0 &
        raw_load > est_extremum
      eff_load = fifelse(capped, est_extremum, raw_load)
      coef*fifelse(square, eff_load^2, eff_load)
    }
    # Raw (unsquared) load for every region, timestamp and panel count
    raw_load_dt = load_dt[square == FALSE, .(
      nerc_adj,
      datetime_utc,
      n_panels,
      raw_load = as.numeric(excess_load)
    )]
    # Every coefficient of a region meets every timestamp of that region
    prod_dt = 
      merge(
        raw_load_dt,
        model_dt_long,
        by = 'nerc_adj',
        allow.cartesian = TRUE
      )
    # Summing across regions (and across the linear and squared terms)
    coef_prod_dt = 
      prod_dt[,.(
        plant_id_eia, 
        datetime_utc, 
        n_panels,
        fit_net_gen = coef_term(
          coef, raw_load, as.numeric(est_extremum), square
        )
      )] |>
      gby(plant_id_eia, datetime_utc, n_panels) |>
      fsum() |>
      setkey(plant_id_eia, datetime_utc)
    # Merging to intercept and log scale table 
    # Losing some plants here that are intercept only
    gen_dt = merge(plant_dt, coef_prod_dt, by = 'plant_id_eia')
    gen_dt = merge(
      gen_dt, 
      epsilon_dt, 
      by = c('plant_id_eia','datetime_utc')
    )
    # First adding intercept
    gen_dt[,
      fit_net_generation_mwh_raw := intercept + 
        fifelse(is.na(fit_net_gen), 0, fit_net_gen)
    ]
    # Now calculating fitted value (censored to [0, namepcap])
    gen_dt[,
      fit_net_generation_mwh := fcase(
        fit_net_generation_mwh_raw + epsilon >= 0 
        & fit_net_generation_mwh_raw + epsilon <= namepcap, 
          fit_net_generation_mwh_raw + epsilon,
        fit_net_generation_mwh_raw + epsilon < 0, 0,
        fit_net_generation_mwh_raw + epsilon > namepcap, namepcap
      )
    ]
    # Now calculating emissions and damages
    gen_dt[,':='(
      fit_emissions_tons_co2e = two_tier(
        fit_net_generation_mwh, net_generation_mwh_split,
        emissions_rate_low_co2e, emissions_rate_high_co2e
      ),
      fit_emissions_tons_so2 = two_tier(
        fit_net_generation_mwh, net_generation_mwh_split,
        emissions_rate_low_so2, emissions_rate_high_so2
      ),
      fit_emissions_tons_nox = two_tier(
        fit_net_generation_mwh, net_generation_mwh_split,
        emissions_rate_low_nox, emissions_rate_high_nox
      ),
      fit_emissions_tons_pm25 = two_tier(
        fit_net_generation_mwh, net_generation_mwh_split,
        emissions_rate_low_pm25, emissions_rate_high_pm25
      ),
      damages = two_tier(
        fit_net_generation_mwh, net_generation_mwh_split,
        md_per_mwh_low, md_per_mwh_high
      )
    )]
    # Summing everything together
    plant_gen_dt = 
      gen_dt[,.(
        plant_id_eia, n_panels, 
        fit_net_generation_mwh, 
        fit_emissions_tons_co2e,
        fit_emissions_tons_so2,
        fit_emissions_tons_nox,
        fit_emissions_tons_pm25,
        damages
      )] |>
      gby(plant_id_eia, n_panels) |>
      fsum()
    return(plant_gen_dt)
  }

# Running everything (15 seconds currently)
# The shocks are drawn once, on the timestamps of the 'cal' region
epsilon_dt = generate_epsilons(
  plant_dt = plant_dt,
  datetime_utc = nerc_load_dt[nerc_adj == 'cal', datetime_utc]
)

# Run the whole pipeline for one panel count: adjust the load, then fit
# production, emissions and damages by plant. All inputs other than
# `n_panels` are the tables defined above in the script's environment.
run_panel_calcs = function(n_panels){
  calc_production(
    load_dt = calc_adj_load(
      n_panels = n_panels,
      state_solar_profile_dt = state_solar_profile_dt,
      nerc_state_xwalk = nerc_state_xwalk,
      panel_output_dt = panel_output_dt,
      nerc_load_dt = nerc_load_dt
    ),
    plant_dt = plant_dt,
    model_dt_long = model_dt_long,
    epsilon_dt = epsilon_dt
  )
}

# Simulating a grid of 101 panel counts from 0 to 2 billion and stacking the
# per-plant results into one table. rbindlist() replaces the superseded
# purrr::map_dfr(), which needs dplyr installed although the script never
# loads it.
plant_gen_dt = 
  rbindlist(
    lapply(
      seq(0, 2e9, by = 2e9/100) |> round(digits = 0),
      run_panel_calcs
    ),
    use.names = TRUE
  )
# Saving the results so the analysis below can start from the file
write.fst(
  plant_gen_dt, 
  path= here('Data/electricity-generation/output/panel-simulations.fst')
)

# Now we get to try to figure out what is going on ----------------------------
# Reloading the saved simulation results
plant_gen_dt = read.fst(
  path = here('Data/electricity-generation/output/panel-simulations.fst'),
  as.data.table = TRUE
)
# First looking at the total electricity produced at each panel count
plant_gen_dt[, 
  list(tot_gen = sum(fit_net_generation_mwh)), 
  keyby = .(n_panels)
]


  # Fuel category and region of every plant, taken from the earliest-year
  # record in each of the two plant-info files.
  # NOTE(review): a plant present in both files gets two rows here, which
  # would duplicate it in later merges - confirm the files do not overlap.
  plant_info_dt = local({
    # Read one plant-info file and keep each plant's earliest-year row
    read_first_year = function(fst_path){
      read.fst(path = here(fst_path), as.data.table = TRUE)[,
        .SD[which.min(year)],
        by = .(plant_id_eia = as.character(plant_id_eia))
      ]
    }
    both_sources = rbind(
      read_first_year("Data/electricity-generation/plant-info-dt-oge.fst"),
      read_first_year("Data/electricity-generation/shaped-info-dt-oge.fst"),
      fill = TRUE
    )
    both_sources[, .(plant_id_eia, fuel_category, nerc_adj)]
  })



  # Summarizing by region and panel count: totals first, then per-MWh rates
  # derived from those totals
  region_sim_dt = 
    merge(
      plant_gen_dt,
      plant_info_dt, 
      by = 'plant_id_eia'
    )[,.(
      tot_gen_mwh = sum(fit_net_generation_mwh),
      tot_co2e_tons = sum(fit_emissions_tons_co2e),
      tot_nox_tons = sum(fit_emissions_tons_nox),
      tot_so2_tons = sum(fit_emissions_tons_so2),
      tot_pm25_tons = sum(fit_emissions_tons_pm25),
      tot_damages = sum(damages)), 
      keyby = .(nerc_adj, n_panels)
    ]
  region_sim_dt[,':='(
    co2e_tons_per_mwh = tot_co2e_tons/tot_gen_mwh,
    nox_tons_per_mwh  = tot_nox_tons/tot_gen_mwh,
    so2_tons_per_mwh  = tot_so2_tons/tot_gen_mwh,
    pm25_tons_per_mwh = tot_pm25_tons/tot_gen_mwh,
    damages_per_mwh = tot_damages/tot_gen_mwh
  )]

  # Plots by region (ggplot2 is already loaded at the top of the script)
  theme_set(theme_minimal())
  
  # Total generation against the number of panels
  ggplot(region_sim_dt, aes(x = n_panels, y = tot_gen_mwh, color = nerc_adj)) + 
    geom_line()
    
  # Total damages against the number of panels.
  # `linewidth` replaces `size`, which ggplot2 deprecated for lines in 3.4.0.
  # NOTE(review): this plot appeared twice verbatim; if the second copy was
  # meant to show a different series, add it here.
  ggplot(
    region_sim_dt, 
    aes(x = n_panels/1e9, y = tot_damages/1e6, color = nerc_adj)
  ) + 
    geom_line(linewidth = 1.25) +
    scale_color_brewer(name = 'Region',palette = 'Dark2') + 
    labs(
      x = 'Number of Panels (Billions)',
      y = 'Damages ($1M)'
    )
  # Looking at differences between pollutants 
  region_sim_dt |>
    melt(
      id.vars = c('nerc_adj','n_panels'),
      measure.vars = patterns('tons$')
    ) |>
    ggplot(aes(x = n_panels, y = value, color= nerc_adj)) + 
    geom_line() + 
    facet_wrap(~variable, scales ='free')

  # Marginal damages per panel: change in total damages between neighbouring
  # panel counts divided by the change in panels. Differences are taken
  # within each region (rows sorted by panel count), so no difference spans
  # two regions. The first and last grid points are then dropped.
  region_sim_dt[order(nerc_adj, n_panels), .(
    n_panels, 
    tot_damages,
    md = c(NA, diff(tot_damages))/c(NA, diff(n_panels)),
    delta_damages = c(NA, diff(tot_damages)),
    delta_gen = c(NA, diff(tot_gen_mwh))
  ), by = nerc_adj][n_panels != 0 & n_panels != 2e9] |>
    ggplot(aes(x =n_panels/1e6, y = md, color = nerc_adj)) + 
    #geom_line(linewidth = 1.25) +
    #geom_point(alpha = 0.3) +
    geom_smooth(se = FALSE, method = 'loess') +
    scale_color_brewer(name = 'Region',palette = 'Dark2') + 
    labs(
      x = 'Number of Panels (Millions)',
      y = 'Marginal Damages ($/panel)'
    )

  
    

    
